package database

import (
	"gorm.io/gorm"
	"gorm.io/plugin/dbresolver"
	"time"
)

// policies maps the policy names accepted in a DBResolverConfig to the
// dbresolver.Policy used for that resolver. DBResolverConfig.Init ignores
// names that are not listed here.
var policies = map[string]dbresolver.Policy{"random": dbresolver.RandomPolicy{}}

// Configure produces a ready-to-use *gorm.DB from stored connection settings.
type Configure interface {
	// Init opens the database using config and returns the resulting handle.
	// open converts a DSN string into the gorm.Dialector to connect with.
	Init(config *gorm.Config, open func(dsn string) gorm.Dialector) (*gorm.DB, error)
}

// ResolverConfigure contributes one resolver configuration to a
// dbresolver.DBResolver plugin.
type ResolverConfigure interface {
	// Init adds this configuration to register and returns the resolver that
	// should be passed to the next ResolverConfigure. register may be nil when
	// no configuration has been applied yet. open converts a DSN string into a
	// gorm.Dialector.
	Init(register *dbresolver.DBResolver, open func(dsn string) gorm.Dialector) *dbresolver.DBResolver
}

// DBConfig is the Configure implementation returned by NewConfigure. It holds
// the primary DSN, connection-pool limits and the resolver configurations
// that Init applies.
type DBConfig struct {
	// dsn is the primary database DSN, turned into a dialector by Init.
	dsn string

	// Maximum idle time and maximum lifetime of a connection, in seconds.
	// Values <= 0 are not applied.
	connMaxIdleTime, connMaxLifetime int

	// Maximum number of idle and open connections. Values <= 0 are not applied.
	maxIdleConn, maxOpenConn int

	// registers are applied in order by Init to build the dbresolver plugin.
	registers []ResolverConfigure
}

// NewConfigure returns a Configure for the primary database at dsn.
//
// connMaxIdleTime and connMaxLifetime are in seconds. maxIdleConn and
// maxOpenConn bound the connection pool. Any of these four that is <= 0 is not
// applied. registers are the resolver configurations (sources, replicas,
// policies) applied in the given order when Init is called.
func NewConfigure(
	dsn string,
	connMaxIdleTime int,
	connMaxLifetime int,
	maxIdleConn int,
	maxOpenConn int,
	registers []ResolverConfigure,
) Configure {
	cfg := new(DBConfig)
	cfg.dsn = dsn
	cfg.connMaxIdleTime = connMaxIdleTime
	cfg.connMaxLifetime = connMaxLifetime
	cfg.maxIdleConn = maxIdleConn
	cfg.maxOpenConn = maxOpenConn
	cfg.registers = registers
	return cfg
}
// Init opens the primary database at d.dsn, installs a dbresolver plugin built
// from d.registers, applies the configured pool limits, and returns the handle.
//
// config is passed to gorm.Open unchanged. open converts a DSN into a
// gorm.Dialector. It is used for the primary DSN and is handed to every
// ResolverConfigure for its sources and replicas.
//
// The ResolverConfigures run in order: the first receives a nil resolver and
// each later one receives the previous result. If none of them produces a
// resolver, one is registered with an empty dbresolver.Config so the pool
// setters below always have a plugin to act on. Pool limits <= 0 are skipped.
//
// On error the returned *gorm.DB is nil, and any primary connection pool opened
// here has been closed.
func (d *DBConfig) Init(config *gorm.Config, open func(string) gorm.Dialector) (*gorm.DB, error) {
	db, err := gorm.Open(open(d.dsn), config)
	if err != nil {
		return nil, err
	}

	var resolver *dbresolver.DBResolver
	for _, rc := range d.registers {
		resolver = rc.Init(resolver, open)
	}
	if resolver == nil {
		resolver = dbresolver.Register(dbresolver.Config{})
	}
	// From here on resolver is guaranteed non-nil.

	if d.connMaxIdleTime > 0 {
		resolver = resolver.SetConnMaxIdleTime(time.Duration(d.connMaxIdleTime) * time.Second)
	}
	if d.connMaxLifetime > 0 {
		resolver = resolver.SetConnMaxLifetime(time.Duration(d.connMaxLifetime) * time.Second)
	}
	if d.maxOpenConn > 0 {
		resolver = resolver.SetMaxOpenConns(d.maxOpenConn)
	}
	if d.maxIdleConn > 0 {
		resolver = resolver.SetMaxIdleConns(d.maxIdleConn)
	}

	if err = db.Use(resolver); err != nil {
		// Don't leak the primary pool or hand back a half-configured handle.
		// A Close failure is ignored: the Use error is the one worth reporting.
		if sqlDB, dbErr := db.DB(); dbErr == nil {
			_ = sqlDB.Close()
		}
		return nil, err
	}
	return db, nil
}

// DBResolverConfig is the ResolverConfigure implementation returned by
// NewResolverConfigure: one dbresolver registration.
type DBResolverConfig struct {
	// DSNs of the write sources and read replicas. Each is converted with the
	// open function passed to Init.
	sources, replicas []string

	// policy is looked up in the package-level policies map. An empty or
	// unknown name leaves the registration's Policy unset.
	policy string

	// tables are passed as the data arguments of the dbresolver registration.
	// They are stored as interface{} to fit its variadic parameter.
	tables []interface{}
}

// NewResolverConfigure returns a ResolverConfigure that registers the given
// source and replica DSNs, under the named policy, for the listed tables.
// Table names are stored as interface{} values so they can be passed straight
// to dbresolver's variadic Register calls.
func NewResolverConfigure(sources []string, replicas []string, policy string, tables []string) ResolverConfigure {
	tableData := make([]interface{}, 0, len(tables))
	for _, table := range tables {
		tableData = append(tableData, table)
	}

	return &DBResolverConfig{
		sources:  sources,
		replicas: replicas,
		policy:   policy,
		tables:   tableData,
	}
}

// Init adds this configuration to register and returns the resulting resolver.
//
// If no sources, replicas or tables are configured, register is returned as-is
// (it may be nil). Otherwise each source and replica DSN is converted with
// open, and the policy named by d.policy is looked up in policies. An empty or
// unknown name leaves Policy unset. The resulting config is then registered for
// d.tables: on a new resolver when register is nil, or on register itself
// otherwise.
func (d *DBResolverConfig) Init(register *dbresolver.DBResolver, open func(string) gorm.Dialector) *dbresolver.DBResolver {
	if len(d.sources) == 0 && len(d.replicas) == 0 && len(d.tables) == 0 {
		return register
	}

	// toDialectors opens every DSN in order; an empty list stays nil.
	toDialectors := func(dsns []string) []gorm.Dialector {
		if len(dsns) == 0 {
			return nil
		}
		out := make([]gorm.Dialector, 0, len(dsns))
		for _, dsn := range dsns {
			out = append(out, open(dsn))
		}
		return out
	}

	// Composite-literal elements are evaluated left to right, so sources are
	// opened before replicas.
	config := dbresolver.Config{
		Sources:  toDialectors(d.sources),
		Replicas: toDialectors(d.replicas),
	}

	if p, found := policies[d.policy]; found && d.policy != "" {
		config.Policy = p
	}

	if register != nil {
		return register.Register(config, d.tables...)
	}
	return dbresolver.Register(config, d.tables...)
}
